{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "从一份良好的数据集中学习，训练过程中完全不更新数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(20000,\n",
       " (tensor([0.9019, 0.4320, 0.8704]),\n",
       "  tensor([-1.5736]),\n",
       "  tensor([0.9651]),\n",
       "  tensor([0.8817, 0.4719, 0.8943]),\n",
       "  tensor([0])))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "#封装数据集\n",
    "class Dataset(torch.utils.data.Dataset):\n",
    "\n",
    "    def __init__(self):\n",
    "        import numpy as np\n",
    "        data = np.loadtxt('离线学习数据.txt')\n",
    "        self.state = torch.FloatTensor(data[:, :3]).reshape(-1, 3)\n",
    "        self.action = torch.FloatTensor(data[:, 3]).reshape(-1, 1)\n",
    "        self.reward = torch.FloatTensor(data[:, 4]).reshape(-1, 1)\n",
    "        self.next_state = torch.FloatTensor(data[:, 5:8]).reshape(-1, 3)\n",
    "        self.over = torch.LongTensor(data[:, 8]).reshape(-1, 1)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.state)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        return self.state[i], self.action[i], self.reward[i], self.next_state[\n",
    "            i], self.over[i]\n",
    "\n",
    "\n",
    "# Build the dataset once; the data-loader cell reads this module-level name\n",
    "dataset = Dataset()\n",
    "\n",
    "# Cell output: number of transitions and the first transition tuple\n",
    "len(dataset), dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "312"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Data loader: shuffled mini-batches of 64; drop_last discards the final\n",
    "# partial batch so every batch has exactly 64 transitions\n",
    "loader = torch.utils.data.DataLoader(dataset=dataset,\n",
    "                                     batch_size=64,\n",
    "                                     shuffle=True,\n",
    "                                     drop_last=True)\n",
    "\n",
    "# Cell output: number of batches per epoch\n",
    "len(loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<__main__.CQL at 0x7fa42ea40af0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from sac import SAC\n",
    "\n",
    "\n",
    "# CQL algorithm: SAC plus a conservative term on the value loss\n",
    "class CQL(SAC):\n",
    "    \"\"\"SAC subclass that supplies the conservative (CQL) value-loss term.\n",
    "\n",
    "    NOTE(review): model_value and get_action_entropy are not defined in this\n",
    "    notebook; they are presumably inherited from sac.SAC, which presumably\n",
    "    also calls get_loss_cql from its value update -- confirm in sac.py.\n",
    "    \"\"\"\n",
    "\n",
    "    def get_loss_cql(self, state, next_state, value):\n",
    "        \"\"\"Return the extra loss 5.0 * (mean logsumexp term - value.mean()).\n",
    "\n",
    "        Args:\n",
    "            state: 2-D tensor, one row per transition.\n",
    "            next_state: 2-D tensor with the same layout as state.\n",
    "            value: tensor whose mean is subtracted from the logsumexp term;\n",
    "                presumably the value estimates for the dataset actions --\n",
    "                TODO confirm against the caller in sac.py.\n",
    "\n",
    "        Returns:\n",
    "            A scalar tensor.\n",
    "        \"\"\"\n",
    "        n_repeat = 5\n",
    "        state_dim = state.shape[-1]\n",
    "\n",
    "        # repeat every state / next_state n_repeat times along the batch axis\n",
    "        state = state.unsqueeze(dim=1).repeat(1, n_repeat,\n",
    "                                              1).reshape(-1, state_dim)\n",
    "        next_state = next_state.unsqueeze(dim=1).repeat(1, n_repeat, 1).reshape(\n",
    "            -1, state_dim)\n",
    "\n",
    "        # three action sets: uniform in [-1, 1], policy at state, policy at\n",
    "        # next_state; the random actions are created on the states' device\n",
    "        rand_action = torch.empty([len(state), 1],\n",
    "                                  device=state.device).uniform_(-1, 1)\n",
    "        curr_action, curr_entropy = self.get_action_entropy(state)\n",
    "        next_action, next_entropy = self.get_action_entropy(next_state)\n",
    "\n",
    "        def state_value(action):\n",
    "            # every action set is scored against the current state, including\n",
    "            # next_action, exactly as the original code did\n",
    "            pair = torch.cat([state, action], dim=1)\n",
    "            return self.model_value(pair).reshape(-1, n_repeat, 1)\n",
    "\n",
    "        value_rand = state_value(rand_action)\n",
    "        value_curr = state_value(curr_action)\n",
    "        value_next = state_value(next_action)\n",
    "\n",
    "        # the entropy terms act as constants: no gradient flows through them\n",
    "        curr_entropy = curr_entropy.detach().reshape(-1, n_repeat, 1)\n",
    "        next_entropy = next_entropy.detach().reshape(-1, n_repeat, 1)\n",
    "\n",
    "        # subtract the per-sample correction from each value set;\n",
    "        # -0.6931471805599453 = math.log(0.5), the log-density of U(-1, 1)\n",
    "        # NOTE(review): whether get_action_entropy returns a log-probability\n",
    "        # or its negative is not visible here -- confirm the sign in sac.py\n",
    "        log_half = -0.6931471805599453\n",
    "        value_rand = value_rand - log_half\n",
    "        value_curr = value_curr - curr_entropy\n",
    "        value_next = value_next - next_entropy\n",
    "\n",
    "        # join the three sets along the sample axis and reduce with\n",
    "        # logsumexp, which does not overflow the way exp().sum().log() can\n",
    "        value_cat = torch.cat([value_rand, value_curr, value_next], dim=1)\n",
    "        loss_cat = torch.logsumexp(value_cat, dim=1).mean()\n",
    "\n",
    "        # this result is meant to be added on top of the value loss\n",
    "        return 5.0 * (loss_cat - value.mean())\n",
    "\n",
    "\n",
    "# Create the agent; play() and train() below use this module-level name\n",
    "cql = CQL()\n",
    "\n",
    "# Cell output: repr of the agent object\n",
    "cql"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEXCAYAAACUBEAgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAdA0lEQVR4nO3de3DTdb7/8VfSXHr9Jm2hiZUW6oJil8tqwRKdc5izdKnacb30D9dhtONydNTAgDicsbuCs87OlB/OrKu7ijvjKJ4/tDt1trqyoNtTtOgYuVSqBbHrBWwsJKVAkt5yafL+/VH7XaJVCbT9NPJ6zHxn7Pf7SfIOmOek329TDCIiICKaYkbVAxDRxYnxISIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIlGB8iUoLxISIllMXn6aefxpw5c5CZmYnKykrs27dP1ShEpICS+Pz1r3/Fhg0b8Oijj+KDDz7A4sWLUV1djd7eXhXjEJECBhUfLK2srMTSpUvx5z//GQCQSCRQUlKCtWvX4uGHH/7B2ycSCRw/fhx5eXkwGAyTPS4RnSMRQX9/P4qLi2E0fv97G9MUzaSLRqNob29HfX29vs9oNKKqqgoej2fc20QiEUQiEf3rnp4elJeXT/qsRHR+vF4vZs2a9b1rpjw+fX19iMfjcDgcSfsdDgc++eSTcW/T0NCA3/3ud9/a7/V6oWnapMxJRKkLhUIoKSlBXl7eD66d8vicj/r6emzYsEH/euwJaprG+BBNQ+dyOmTK4zNjxgxkZGTA7/cn7ff7/XA6nePexmq1wmq1TsV4RDRFpvxql8ViQUVFBVpbW/V9iUQCra2tcLlcUz0OESmi5NuuDRs2oK6uDkuWLME111yDP/7xjxgcHMTdd9+tYhwiUkBJfG6//XacPHkSmzdvhs/nw89+9jO88cYb3zoJTUQ/Xkp+zudChUIh2Gw2BINBnnAmmkZSeW3ys11EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpATjQ0RKMD5EpETK8dmzZw9uuukmFBcXw2Aw4NVXX006LiLYvHkzLrnkEmRlZaGqqgqffvpp0prTp09j1apV0DQNdrsdq1evxsDAwAU9ESJKLynHZ3BwEIsXL8bTTz897vGtW7fiqaeewrPPPou9e/ciJycH1dXVCIfD+ppVq1bh8OHDaGlpwY4dO7Bnzx7ce++95/8siCj9yAUAIM3NzfrXiURCnE6nPP744/q+QCAgVqtVXn75ZRER+fjjjwWA7N+/X1+za9cuMRgM0tPTc06PGwwGBYAEg8ELGZ+IJlgqr80JPedz9OhR+Hw+VFVV6ftsNhsqKyvh8XgAAB6PB3a7HUuWLNHXVFVVwWg0Yu/evePebyQSQSgUStqIKL1NaHx8Ph8AwOFwJO13OBz6MZ/Ph6KioqTjJpMJBQUF+ppvamhogM1m07eSkpKJHJuIFEiLq1319fUIBoP65vV6VY9ERBdoQuPjdDoBAH6/P2m/3+/XjzmdTvT29iYdHxkZwenTp/U132S1WqFpWtJGROltQuNTVlYGp9OJ1tZWfV8oFMLevXvhcrkAAC6XC4FAAO3t7fqa3bt3I5FIoLKyciLHIaJpzJTqDQYGBvDZZ5/pXx89ehQdHR0oKChAaWkp1q9fj9///veYN28eysrKsGnTJhQXF+OWW24BAFx55ZW4/vrrcc899+DZZ59FLBbDmjVr8Ktf/QrFxcUT9sSIaJpL9VLaW2+9JQC+tdXV1YnI6OX2TZs2icPhEKvVKitWrJCurq6k+zh16pTccccdkpubK5qmyd133y39/f3n
PAMvtRNNT6m8Ng0iIgrbd15CoRBsNhuCwSDP/xBNI6m8NtPiahcR/fgwPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREowPkSkBONDREqkFJ+GhgYsXboUeXl5KCoqwi233IKurq6kNeFwGG63G4WFhcjNzUVtbS38fn/Smu7ubtTU1CA7OxtFRUXYuHEjRkZGLvzZEFHaSCk+bW1tcLvdeP/999HS0oJYLIaVK1dicHBQX/Pggw/i9ddfR1NTE9ra2nD8+HHcdttt+vF4PI6amhpEo1G89957ePHFF7F9+3Zs3rx54p4VEU1/cgF6e3sFgLS1tYmISCAQELPZLE1NTfqaI0eOCADxeDwiIrJz504xGo3i8/n0Ndu2bRNN0yQSiZzT4waDQQEgwWDwQsYnogmWymvzgs75BINBAEBBQQEAoL29HbFYDFVVVfqa+fPno7S0FB6PBwDg8XiwcOFCOBwOfU11dTVCoRAOHz487uNEIhGEQqGkjYjS23nHJ5FIYP369bjuuuuwYMECAIDP54PFYoHdbk9a63A44PP59DVnh2fs+Nix8TQ0NMBms+lbSUnJ+Y5NRNPEecfH7Xbj0KFDaGxsnMh5xlVfX49gMKhvXq930h+TiCaX6XxutGbNGuzYsQN79uzBrFmz9P1OpxPRaBSBQCDp3Y/f74fT6dTX7Nu3L+n+xq6Gja35JqvVCqvVej6jEtE0ldI7HxHBmjVr0NzcjN27d6OsrCzpeEVFBcxmM1pbW/V9XV1d6O7uhsvlAgC4XC50dnait7dXX9PS0gJN01BeXn4hz4WI0khK73zcbjdeeuklvPbaa8jLy9PP0dhsNmRlZcFms2H16tXYsGEDCgoKoGka1q5dC5fLhWXLlgEAVq5cifLyctx5553YunUrfD4fHnnkEbjdbr67IbqYpHIZDcC42wsvvKCvGR4elgceeEDy8/MlOztbbr31Vjlx4kTS/Rw7dkxuuOEGycrKkhkzZshDDz0ksVjsnOfgpXai6SmV16ZBRERd+s5PKBSCzWZDMBiEpmmqx5kyZ/9VxQcGEDtzBub8fGTk5ur7DQaDitGIAKT22jyvE8409SSRwHB3N8688w6GvvgCsVOnED15EubCQlidTmRfdhmyLrsMVocDlpkzkZGdDXwdIgaJpiPGJw3IyAj6Wltx/KWXMHLmTNKx+OAgwt3dCO7bBxgMyMjJgclmg/WSS0aDNHs2rMXFsDocMGZmwmA0AgYDg0TKMT7TnIjgjMeDnu3bET/rM3TfsRjxgQHEBwYQ6elB6MABwGhERnY2MnJzkXnppciaMweZs2Yhq7R0NEhWKwxm82iUiKYQ4zPNjQSDOP7SSz8cnu+SSOhBivp8CLW3A0YjjFYrMrKzkVlaiqxZs5A1Zw6yZs+GZeZMGLOyRqPEd0c0iRifaW7gyBFEjh+f2DtNJJAYHkZieBixU6fQf/AgYDTCYDLBpGnIvPRSZBYXI+uyy5A1Zw4shYXIyM1lkGhCMT7TXSIBTMUFyUQCEo0i1teHWF8f+j/8cPTcUEYGTHY7rA4HMi+9FNk/+Qmy5syBubAQZpsNBotFvwuGiVLB+Ex3Kl/QIpCRET1IA1//1gFDRgZM+fmwFBYis7R0NEglJbDMnAlzQQEMJhOvtNEPYnymuZx582CZORPRkydVj6KTeFwP0mBXF061tMBgNsOkaTDn5yNr9mxklZUhc9YsZBYXw5yfPxoko5ExIh3jM82ZZ8yAo7YWXz3/PCQaVT3Od5JYDLFTpxA7dQpDn30GADCYzcjIzYVZ00bPH82ejazSUmReeilMdvvoVbaMDAbpIsX4THMGgwEzqqoQ7++H
v7kZ8aEh1SOdM4nFMHLmDEbOnMHwl18CAAwWC4xWKyyFhciaPRuZJSWjV9pKSka/ZbNYGKOLBOOTBowWCxy1tci54gqcam3FwJEjGBkYQCKNQjRGolHEo1EM9/dj+NgxAIDBZIIxMxO5P/0pim68EXmLFsGQkaF2UJp0/GxXGhERIJHAyMAAIn4/IsePY+iLLzD8+eeI9vUhduYMEuGw6jEvSEZ2Norvugszq6sZoDTEz3b9SBkMBiAjA2abDWabDTnz5qFg+XJABCOhEKJ9fYj09GDo888xdPQooidPItrXB4nFpuZyfYpEBIFoFJ/398NmseAneXnA0BB6/vd/YZk5E7YlS/gt2I8Y45PG9BemwQCz3Q6z3Y6cuXNRsHw5ZGQEI/39iJ05g7DXi8HPPkPY60XE50P05ElIPD76M0SKiAi6Bwex6eBBdAWDyDGZ8N+XX47by8qAoSH4m5uRt2ABMrKylM1Ik4vx+ZEymEww5+fDnJ+P7MsuQ/5//ickFkN8cBCxQADh7m4Mfv45Ij09GPZ6MRIIQGKx0ShNAQHw/zo78XEgAAAIxWL485Ej+KndjsUFBRj64gvEh4cZnx8xxuciYTAYRq80WSyjQSor04OUCIcRPX0a4e5uDB07hrDXi7DXi5FQCIlwGDJJ/5psKBZL+jqaSCAyRfEj9Rifi9jZQTJpGrLnzEH+f/wHZGQEiWgU0ZMnEe7pQfjLLzF09CjCX32FkVBo9HL/BUbCAOC/nE58Fgph5OvzUZdrGmZ//YvRMqxWftL+R47xoSQGgwEGsxlGsxmmnBxkzZ4NXHvt6IdRIxFEensROXECw8eOYeiLLxA5fhyxQADxgYGUTmobDAbUzZ2LPLMZ/3fiBC7JysI9l1+OosxMAED+8uUwXURXMi9GvNROKdP/lxFBfHgYsb4+RPz+0cv+X3wxelK7r2/014D8wP9eIgLB6DshYDRKueXlKPuf/4Hl638Jl9IHL7XTpDr7KpspJwemnBxklpbCtnTpaJC+Pqkd8fkw9PnnGD52DJETJxD1+xGPRJI+qW8wGP4dHrMZ2lVXYdbq1QzPRYDxoQmRFKS8PJjy8pBVUgLbkiWjv9BsaAgj/f0If/UVho8exXB3N8I9PcjIzkZ2WRmM2dnIvfJK5F55Ja9wXSQYH5pUYz8YORakzOJi2JYuhcTjSITDMHz9a17p4sP40JQzGAyjn+c665/8oYsPr2USkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRKMDxEpwfgQkRIpxWfbtm1YtGgRNE2DpmlwuVzYtWuXfjwcDsPtdqOwsBC5ubmora2F3+9Puo/u7m7U1NQgOzsbRUVF2LhxI0Ym6ReUE9H0lVJ8Zs2ahS1btqC9vR0HDhzAz3/+c9x88804fPgwAODBBx/E66+/jqamJrS1teH48eO47bbb9NvH43HU1NQgGo3ivffew4svvojt27dj8+bNE/usiGj6kwuUn58vzz33nAQCATGbzdLU1KQfO3LkiAAQj8cjIiI7d+4Uo9EoPp9PX7Nt2zbRNE0ikch3PkY4HJZgMKhvXq9XAEgwGLzQ8YloAgWDwXN+bZ73OZ94PI7GxkYMDg7C5XKhvb0dsVgMVVVV+pr58+ejtLQUHo8HAODxeLBw4UI4HA59TXV1NUKhkP7uaTwNDQ2w2Wz6VlJScr5jE9E0kXJ8Ojs7kZubC6vVivvuuw/Nzc0oLy+Hz+eDxWKB3W5PWu9wOODz+QAAPp8vKTxjx8eOfZf6+noEg0F983q9qY5NRNNMyr9G9YorrkBHRweCwSBeeeUV1NXVoa2tbTJm01mtVlit1kl9DCKaWinHx2KxYO7cuQCAiooK7N+/H08++SRuv/12RKNRBAKBpHc/fr8fTqcTAOB0OrFv376k+xu7Gja2hoguDhf8cz6JRAKRSAQVFRUwm81obW3Vj3V1daG7uxsulwsA4HK50NnZ
id7eXn1NS0sLNE1DeXn5hY5CRGkkpXc+9fX1uOGGG1BaWor+/n689NJLePvtt/Hmm2/CZrNh9erV2LBhAwoKCqBpGtauXQuXy4Vly5YBAFauXIny8nLceeed2Lp1K3w+Hx555BG43W5+W0V0kUkpPr29vbjrrrtw4sQJ2Gw2LFq0CG+++SZ+8YtfAACeeOIJGI1G1NbWIhKJoLq6Gs8884x++4yMDOzYsQP3338/XC4XcnJyUFdXh8cee2xinxURTXv8t9qJaMKk8trkZ7uISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISAnGh4iUYHyISIkLis+WLVtgMBiwfv16fV84HIbb7UZhYSFyc3NRW1sLv9+fdLvu7m7U1NQgOzsbRUVF2LhxI0ZGRi5kFCJKM+cdn/379+Mvf/kLFi1alLT/wQcfxOuvv46mpia0tbXh+PHjuO222/Tj8XgcNTU1iEajeO+99/Diiy9i+/bt2Lx58/k/CyJKP3Ie+vv7Zd68edLS0iLLly+XdevWiYhIIBAQs9ksTU1N+tojR44IAPF4PCIisnPnTjEajeLz+fQ127ZtE03TJBKJjPt44XBYgsGgvnm9XgEgwWDwfMYnokkSDAbP+bV5Xu983G43ampqUFVVlbS/vb0dsVgsaf/8+fNRWloKj8cDAPB4PFi4cCEcDoe+prq6GqFQCIcPHx738RoaGmCz2fStpKTkfMYmomkk5fg0Njbigw8+QENDw7eO+Xw+WCwW2O32pP0OhwM+n09fc3Z4xo6PHRtPfX09gsGgvnm93lTHJqJpxpTKYq/Xi3Xr1qGlpQWZmZmTNdO3WK1WWK3WKXs8Ipp8Kb3zaW9vR29vL66++mqYTCaYTCa0tbXhqaeegslkgsPhQDQaRSAQSLqd3++H0+kEADidzm9d/Rr7emwNEf34pRSfFStWoLOzEx0dHfq2ZMkSrFq1Sv9vs9mM1tZW/TZdXV3o7u6Gy+UCALhcLnR2dqK3t1df09LSAk3TUF5ePkFPi4imu5S+7crLy8OCBQuS9uXk5KCwsFDfv3r1amzYsAEFBQXQNA1r166Fy+XCsmXLAAArV65EeXk57rzzTmzduhU+nw+PPPII3G43v7UiuoikFJ9z8cQTT8BoNKK2thaRSATV1dV45pln9OMZGRnYsWMH7r//frhcLuTk5KCurg6PPfbYRI9CRNOYQURE9RCpCoVCsNlsCAaD0DRN9ThE9LVUXpv8bBcRKcH4EJESjA8RKcH4EJESjA8RKcH4EJESjA8RKcH4EJESjA8RKcH4EJESjA8RKcH4EJESjA8RKcH4EJESjA8RKcH4EJESjA8RKcH4EJESjA8RKcH4EJESjA8RKcH4EJESjA8RKcH4EJESjA8RKcH4EJESjA8RKcH4EJESjA8RKcH4EJESjA8RKcH4EJESjA8RKcH4EJESjA8RKcH4EJESjA8RKWFSPcD5EBEAQCgUUjwJEZ1t7DU59hr9PmkZn1OnTgEASkpKFE9CROPp7++HzWb73jVpGZ+CggIAQHd39w8+wekmFAqhpKQEXq8XmqapHuecce6pla5ziwj6+/tRXFz8g2vTMj5G4+ipKpvNllZ/MWfTNC0tZ+fcUysd5z7XNwQ84UxESjA+RKREWsbHarXi0UcfhdVqVT1KytJ1ds49tdJ17lQY5FyuiRERTbC0fOdDROmP8SEiJRgfIlKC8SEiJRgfIlIiLePz9NNPY86cOcjMzERlZSX27dundJ49e/bgpptuQnFxMQwGA1599dWk4yKCzZs345JLLkFWVhaqqqrw6aefJq05ffo0Vq1aBU3TYLfbsXr1agwMDEzq
3A0NDVi6dCny8vJQVFSEW265BV1dXUlrwuEw3G43CgsLkZubi9raWvj9/qQ13d3dqKmpQXZ2NoqKirBx40aMjIxM2tzbtm3DokWL9J/+dblc2LVr17SeeTxbtmyBwWDA+vXr0272CSFpprGxUSwWizz//PNy+PBhueeee8Rut4vf71c2086dO+W3v/2t/O1vfxMA0tzcnHR8y5YtYrPZ5NVXX5UPP/xQfvnLX0pZWZkMDw/ra66//npZvHixvP/++/LOO+/I3Llz5Y477pjUuaurq+WFF16QQ4cOSUdHh9x4441SWloqAwMD+pr77rtPSkpKpLW1VQ4cOCDLli2Ta6+9Vj8+MjIiCxYskKqqKjl48KDs3LlTZsyYIfX19ZM299///nf5xz/+If/617+kq6tLfvOb34jZbJZDhw5N25m/ad++fTJnzhxZtGiRrFu3Tt+fDrNPlLSLzzXXXCNut1v/Oh6PS3FxsTQ0NCic6t++GZ9EIiFOp1Mef/xxfV8gEBCr1Sovv/yyiIh8/PHHAkD279+vr9m1a5cYDAbp6emZstl7e3sFgLS1telzms1maWpq0tccOXJEAIjH4xGR0fAajUbx+Xz6mm3btommaRKJRKZs9vz8fHnuuefSYub+/n6ZN2+etLS0yPLly/X4pMPsEymtvu2KRqNob29HVVWVvs9oNKKqqgoej0fhZN/t6NGj8Pl8STPbbDZUVlbqM3s8HtjtdixZskRfU1VVBaPRiL17907ZrMFgEMC/f2tAe3s7YrFY0uzz589HaWlp0uwLFy6Ew+HQ11RXVyMUCuHw4cOTPnM8HkdjYyMGBwfhcrnSYma3242ampqkGYH0+POeSGn1qfa+vj7E4/GkP3gAcDgc+OSTTxRN9f18Ph8AjDvz2DGfz4eioqKk4yaTCQUFBfqayZZIJLB+/Xpcd911WLBggT6XxWKB3W7/3tnHe25jxyZLZ2cnXC4XwuEwcnNz0dzcjPLycnR0dEzbmQGgsbERH3zwAfbv3/+tY9P5z3sypFV8aPK43W4cOnQI7777rupRzskVV1yBjo4OBINBvPLKK6irq0NbW5vqsb6X1+vFunXr0NLSgszMTNXjKJdW33bNmDEDGRkZ3zr77/f74XQ6FU31/cbm+r6ZnU4nent7k46PjIzg9OnTU/K81qxZgx07duCtt97CrFmz9P1OpxPRaBSBQOB7Zx/vuY0dmywWiwVz585FRUUFGhoasHjxYjz55JPTeub29nb09vbi6quvhslkgslkQltbG5566imYTCY4HI5pO/tkSKv4WCwWVFRUoLW1Vd+XSCTQ2toKl8ulcLLvVlZWBqfTmTRzKBTC3r179ZldLhcCgQDa29v1Nbt370YikUBlZeWkzSYiWLNmDZqbm7F7926UlZUlHa+oqIDZbE6avaurC93d3Umzd3Z2JsWzpaUFmqahvLx80mb/pkQigUgkMq1nXrFiBTo7O9HR0aFvS5YswapVq/T/nq6zTwrVZ7xT1djYKFarVbZv3y4ff/yx3HvvvWK325PO/k+1/v5+OXjwoBw8eFAAyB/+8Ac5ePCgfPnllyIyeqndbrfLa6+9Jh999JHcfPPN415qv+qqq2Tv3r3y7rvvyrx58yb9Uvv9998vNptN3n77bTlx4oS+DQ0N6Wvuu+8+KS0tld27d8uBAwfE5XKJy+XSj49d+l25cqV0dHTIG2+8ITNnzpzUS78PP/ywtLW1ydGjR+Wjjz6Shx9+WAwGg/zzn/+ctjN/l7OvdqXb7Bcq7eIjIvKnP/1JSktLxWKxyDXXXCPvv/++0nneeustAfCtra6uTkRGL7dv2rRJHA6HWK1WWbFihXR1dSXdx6lTp+SOO+6Q3Nxc0TRN7r77bunv75/UucebGYC88MIL+prh4WF54IEHJD8/X7Kzs+XWW2+VEydOJN3PsWPH5IYbbpCsrCyZMWOGPPTQQxKLxSZt7l//+tcye/ZssVgsMnPmTFmxYoUenuk6
83f5ZnzSafYLxd/nQ0RKpNU5HyL68WB8iEgJxoeIlGB8iEgJxoeIlGB8iEgJxoeIlGB8iEgJxoeIlGB8iEgJxoeIlPj/gHWXSSWCc9kAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "# Environment definition\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    \"\"\"Pendulum-v1 wrapper with a simplified reset/step interface.\n",
    "\n",
    "    reset() returns only the state, step() returns (state, reward, over),\n",
    "    the reward is shifted by (reward + 8) / 8, and over is forced to True\n",
    "    once 200 steps have been taken since the last reset.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        pendulum = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(pendulum)\n",
    "        self.env = pendulum\n",
    "        # number of steps taken since the last reset\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Reset the env and the step counter; return the state only.\"\"\"\n",
    "        observation, _info = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return observation\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Apply action * 2 to the env and return (state, reward, over).\"\"\"\n",
    "        scaled_action = [action * 2]\n",
    "        observation, raw_reward, terminated, truncated, _info = self.env.step(\n",
    "            scaled_action)\n",
    "\n",
    "        # count this step, then cap the episode at 200 steps\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= 200:\n",
    "            finished = True\n",
    "        else:\n",
    "            finished = terminated or truncated\n",
    "\n",
    "        # shift the reward to make training easier\n",
    "        return observation, (raw_reward + 8) / 8, finished\n",
    "\n",
    "    # Draw the current game frame\n",
    "    def show(self):\n",
    "        \"\"\"Render the env and display the frame with figsize=(3, 3).\"\"\"\n",
    "        from matplotlib import pyplot\n",
    "        pyplot.figure(figsize=(3, 3))\n",
    "        frame = self.env.render()\n",
    "        pyplot.imshow(frame)\n",
    "        pyplot.show()\n",
    "\n",
    "\n",
    "# Create the environment; play() below uses this module-level name\n",
    "env = MyWrapper()\n",
    "\n",
    "# Reset once and display the first frame\n",
    "env.reset()\n",
    "\n",
    "env.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "65.49787483233098"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "import random\n",
    "\n",
    "\n",
    "# Play one episode and record the transitions\n",
    "def play(show=False):\n",
    "    \"\"\"Run one episode in the global env using the policy network of cql.\n",
    "\n",
    "    Args:\n",
    "        show: when True, clear the cell output and redraw the frame after\n",
    "            every step.\n",
    "\n",
    "    Returns:\n",
    "        (data, reward_sum): data is a list with one\n",
    "        (state, action, reward, next_state, over) tuple per step;\n",
    "        reward_sum is the sum of the rewards returned by env.step.\n",
    "    \"\"\"\n",
    "    data = []\n",
    "    reward_sum = 0\n",
    "\n",
    "    state = env.reset()\n",
    "    over = False\n",
    "    while not over:\n",
    "        # sample the action from Normal(mu, sigma); no_grad because this is\n",
    "        # inference only and just the scalar values are read via .item()\n",
    "        # NOTE(review): no squashing is applied to the sample here -- confirm\n",
    "        # this matches how sac.SAC turns (mu, sigma) into actions\n",
    "        with torch.no_grad():\n",
    "            mu, sigma = cql.model_action(\n",
    "                torch.FloatTensor(state).reshape(1, 3))\n",
    "        action = random.normalvariate(mu=mu.item(), sigma=sigma.item())\n",
    "\n",
    "        next_state, reward, over = env.step(action)\n",
    "\n",
    "        data.append((state, action, reward, next_state, over))\n",
    "        reward_sum += reward\n",
    "\n",
    "        state = next_state\n",
    "\n",
    "        if show:\n",
    "            display.clear_output(wait=True)\n",
    "            env.show()\n",
    "\n",
    "    return data, reward_sum\n",
    "\n",
    "\n",
    "# Cell output: total reward of one episode played before the training cell\n",
    "play()[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 8.242215468858047\n",
      "10 56.35281865833723\n",
      "20 89.26041464918394\n",
      "30 79.48376752327896\n",
      "40 86.18275362791601\n",
      "50 100.60658900020027\n",
      "60 142.7949989728566\n",
      "70 152.98101699084484\n",
      "80 127.25442772215243\n",
      "90 135.99060925087448\n",
      "100 155.7190656614068\n",
      "110 140.29182931592254\n",
      "120 144.72175596559498\n",
      "130 126.16712141465396\n",
      "140 131.33781911515845\n",
      "150 179.60470230260688\n",
      "160 178.77780327664752\n",
      "170 179.13341488648996\n",
      "180 179.33675428908919\n",
      "190 179.39080509646152\n"
     ]
    }
   ],
   "source": [
    "def train(epochs=200, eval_every=10, eval_episodes=20):\n",
    "    \"\"\"Train cql on the offline loader and print periodic test results.\n",
    "\n",
    "    The environment is used only for evaluation through play(); the\n",
    "    training batches come from loader alone.\n",
    "\n",
    "    Args:\n",
    "        epochs: number of passes over loader.\n",
    "        eval_every: evaluate whenever epoch % eval_every == 0.\n",
    "        eval_episodes: number of play() episodes averaged per evaluation.\n",
    "    \"\"\"\n",
    "    for epoch in range(epochs):\n",
    "        for state, action, reward, next_state, over in loader:\n",
    "            cql.train_value(state, action, reward, next_state, over)\n",
    "            cql.train_action(state)\n",
    "\n",
    "        if epoch % eval_every == 0:\n",
    "            # mean episode reward over eval_episodes fresh episodes\n",
    "            test_result = sum(play()[-1]\n",
    "                              for _ in range(eval_episodes)) / eval_episodes\n",
    "            print(epoch, test_result)\n",
    "\n",
    "\n",
    "# Run the offline training loop; progress lines are printed as 'epoch mean_reward'\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEXCAYAAACUBEAgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAcpElEQVR4nO3dfWxb9b0/8LcfYufxOCRt7OY2oZlaKFEfGGmbekhjvzVrYBmjI/zUoYoF6A9E51ZtM1UjGw0ampT+ijQGW2mni0bRlSBT0QqjtLDctARQ3afQjPSBDO7tSH5N7fSB2GnS2I79+f1Bc27dpsVuY3/t5v2SjkTO+Th+m81vjs+DYxARARFRkhlVByCiiYnlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESigrn02bNmHatGnIzMxEZWUlDhw4oCoKESmgpHz+8pe/oL6+Hs8++yw++eQTzJ07F9XV1ejr61MRh4gUMKi4sbSyshLz58/HH//4RwBAJBJBSUkJVq1ahaeffvobHx+JRNDb24u8vDwYDIZExyWiGIkIBgYGUFxcDKPx2vs25iRl0gWDQbS3t6OhoUFfZzQaUVVVBbfbPeZjAoEAAoGA/vPJkydRXl6e8KxEdH16enowderUa84kvXzOnDmDcDgMu90etd5ut+Ozzz4b8zFNTU34zW9+c8X6np4eaJqWkJxEFD+/34+SkhLk5eV942zSy+d6NDQ0oL6+Xv959AVqmsbyIUpBsRwOSXr5TJo0CSaTCV6vN2q91+uFw+EY8zFWqxVWqzUZ8YgoSZJ+tstisaCiogKtra36ukgkgtbWVjidzmTHISJFlHzsqq+vR11dHebNm4cFCxbg97//PQYHB/HYY4+piENECigpn6VLl+L06dNobGyEx+PBnXfeiffee++Kg9BEdPNScp3PjfL7/bDZbPD5fDzgTJRC4nlv8t4uIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhIibjL58MPP8T999+P4uJiGAwGvPXWW1HbRQSNjY2YMmUKsrKyUFVVhc8//zxq5ty5c1i2bBk0TUN+fj6WL1+O8+fP39ALIaL0Enf5DA4OYu7cudi0adOY2zdu3IiXXnoJW7Zswf79+5GTk4Pq6moMDw/rM8uWLcPRo0fR0tKCHTt24MMPP8STTz55/a+CiNKP3AAAsn37dv3nSCQiDodDnn/+eX1df3+/WK1WeeONN0RE5NixYwJADh48qM/s2rVLDAaDnDx5Mqbn9fl8AkB8Pt+NxCeicRbPe3Ncj/mcOHECHo8HVVVV+jqbzYbKykq43W4AgNvtRn5+PubNm6fPVFVVwWg0Yv/+/WP+3kAgAL/fH7UQUXob1/LxeDwAALvdHrXebrfr2zweD4qKiqK2m81mFBQU6DOXa2pqgs1m05eSkpLxjE1ECqTF2a6Ghgb4fD596enpUR2JiG7QuJaPw+EAAHi93qj1Xq9X3+ZwONDX1xe1fWRkBOfOndNnLme1WqFpWtRCROltXMunrKwMDocDra2t+jq/34/9+/fD6XQCAJxOJ/r7+9He3q7P7N69G5FIBJWVleMZh4hSmDneB5w/fx5ffPGF/vOJEyfQ0dGBgoIClJaWYs2aNfjtb3+LGTNmoKysDOvXr0dxcTGWLFkCALjjjjtw77334oknnsCWLVsQCoWwcuVK/PSnP0VxcfG4vTAiSnHxnkrbs2ePALhiqaurE5GvT7evX79e7Ha7WK1WWbRokXR1dUX9jrNnz8rDDz8subm5ommaPPbYYzIw
MBBzBp5qJ0pN8bw3DSIiCrvvuvj9fthsNvh8Ph7/IUoh8bw30+JsFxHdfFg+RKQEy4eIlIj7bBdRPEYPKYa++grBvj7IyAgy8vNhcThgMJlgMBgUJyRVWD6UUJGhIfTt3Imzra0Inj4NCYdhzstD3pw5mLJ0KTJLSlhAExTLhxImPDSEnldewdk9e4BIRF8/4vPhq48+woUvv8S0tWuR/a1vsYAmIB7zoYQQEfS9++4VxXOp4e5u/L9//3eEBweTnI5SAcuHEiJ07hzOvP/+VYtn1GBXF3yHDiUpFaUSlg8lRHhoCKGzZ79xTsJhBE6dQhpe60o3iOVDCRE6cybmQhnx+79xD4luPiwfSoiv9u6NuVCGT56EhMMJTkSphuVDCWEwxv5/reHubsjISALTUCpi+VBC5Nx+OxBHAYHHfCYclg8lhKWoCIjx2h2JRBC+5E8r0cTA8qGEMGsaYr1sUEKhrw8604TC8qHEiOOK5fCFC7jw5ZcJDEOpiOVDCWHMyIApJye24UgE4aEhXuszwbB8KCHM+fnIjOPvq0kolMA0lIpYPpQQRosFptzcmOdD/f084zXBsHwoIQxGIwwmU8zzQ198waucJxiWDyWMMTMz5tlgHLdj0M2B5UMJkzdrVuzDIvzYNcGwfChhLIWFMc9KKITw0FAC01CqYflQwhizs2OeHRkcRMDrTWAaSjUsH0oYg9EY8/1dkaEhBPv6EpyIUgnLhxImo7AQ1qKiuB7Dg84TB8uHEsaUnR3XtT7hCxcSmIZSDcuHEsZotcZ1uj107lwC01CqYflQYsVxg+n5o0cTGIRSDcuHEspSUBDzbHhwkNf6TCAsH0qonDvuiHlWIhEIb7GYMFg+lFAZcez5RIaHEeGFhhMGy4cSypiREfNs6Nw5Xmg4gbB8KGEMBgOMFgsMZnNM85HhYf7p5AmE5UMJZZ0yJa6PXhDhhYYTRFzl09TUhPnz5yMvLw9FRUVYsmQJurq6omaGh4fhcrlQWFiI3Nxc1NbWwnvZrnR3dzdqamqQnZ2NoqIirFu3DiP8u003JVNODoxWa8zz/CL5iSOu8mlra4PL5cK+ffvQ0tKCUCiExYsXY/CSXeW1a9finXfewbZt29DW1obe3l48+OCD+vZwOIyamhoEg0Hs3bsXr732GrZu3YrGxsbxe1WUMowWCwxxHPcJnD6dwDSUSgxyA/u4p0+fRlFREdra2vDd734XPp8PkydPxuuvv46HHnoIAPDZZ5/hjjvugNvtxsKFC7Fr1y786Ec/Qm9vL+x2OwBgy5Yt+OUvf4nTp0/DYrF84/P6/X7YbDb4fD5omna98SkJJBLB8fp6XPjv/45pXquowPTGRhjiuDiRUkc8780bOubj8/kAAAUXP9O3t7cjFAqhqqpKn5k5cyZKS0vhdrsBAG63G7Nnz9aLBwCqq6vh9/tx9CpXuAYCAfj9/qiF0kdWaWnMs/yb7RPHdZdPJBLBmjVrcPfdd2PWxW+s83g8sFgsyM/Pj5q12+3weDz6zKXFM7p9dNtYmpqaYLPZ9KUkjr+KQIoZDMj+1rdiHpdQiAU0QVx3+bhcLhw5cgTNzc3jmWdMDQ0N8Pl8+tLT05Pw56TxY7bZYp4NDw4iwrvbJ4TrKp+VK1dix44d2LNnD6ZOnaqvdzgcCAaD6O/vj5r3er1wOBz6zOVnv0Z/Hp25nNVqhaZpUQulB4PBEPMXigHAcG8vgmfOJDARpYq4ykdEsHLlSmzfvh27d+9GWVlZ1PaKigpkZGSgtbVVX9fV1YXu7m44nU4AgNPpRGdnJ/ou+da6lpYWaJqG8vLyG3ktlKLMmgZDDCcSgIsfu3jZxYQQ26WnF7lcLrz++ut4++23kZeXpx+jsdlsyMrKgs1mw/Lly1FfX4+CggJomoZVq1bB6XRi4cKFAIDFixejvLwcjzzyCDZu3AiPx4NnnnkGLpcL1jiuB6H0kVlcDFNODkaCwZjmecxnYoirfDZv3gwA+N73vhe1
/tVXX8Wjjz4KAHjhhRdgNBpRW1uLQCCA6upqvPzyy/qsyWTCjh07sGLFCjidTuTk5KCurg7PPffcjb0SSlnmvLz47vH66qsEpqFUEVf5xHJJUGZmJjZt2oRNmzZddebWW2/Fzp0743lqSmMGs/nrL5OPhQgCvb2JDUQpgfd2UXLEcdHg0H/9VwKDUKpg+VDCGUwm5MycGddjeHPpzY/lQ4lnNMJ6lcsoxhK+cIFnvCYAlg8lhSknJ+bZkf5+RAKBBKahVMDyoYSL9ybRCz09/GqNCYDlQ0mR+W//FvOFhohE+FcsJgCWDyWF1eGAMdbyAfixawJg+VBSmPPyYDCZYpoVEQTPnk1wIlKN5UPJYTTGfq1POIzznZ2JzUPKsXwoKQwmEyyFhTHPR2K8D4zSF8uHksKQkYHMS77RUEQQCIdxZngYw2PcSCrhMC80vMnFdW8X0fUymEzIuPilYiKC7sFB9A4NId9igcVoROZlx4NGzp+HjIzE9eXzlF6450NJY7BYICLoGRxEfzCI+ZMmwZGVhX/6/fjc70fkkj2doNfLj143Oe75UFIYDAYYjEaEIhH0Dg1h3qRJODk0hPWHD6PL50OO2Yz/c9ttWFpWBiO+/kbDSCAAxHFlNKUXlg8lTc5tt+G8wYB8iwVGgwH/t7MTxy5+5a4/FMIfjx9Huc2GERFUZGerDUsJx49dlDSWyZNhMJlguvjdPv5QKGp7MBLB2WAQ+06fxjBvLL3psXwoaSxFRdDmzkVYBAYA/8vhgPmSa39u0zTM1DTcVViIvNzc2L+AjNISP3ZR0piysnD7o4/isNuNsAjqpk9HXkYG/vPUKUzJysITt92GkAimZGWh4J57YOZfKbmp8T8tlFTatGmY9/jjODY4CAD439OmYYvTiWfvvBNZZjM8Q0OYuWAB7EuWcM/nJsc9H0oqg9GIbz/0ECIADv/HfyBrcBAGEQQiEYRNJtxz332Y/uSTsFz8E9x082L5UNIZMzIwb+lS3Hb33eg9eBCBU6dgzcnBlDlzkFdeDlNWluqIlAQsH1LCYDTCVloK2yW3XNDEwg/VRKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEiJuMpn8+bNmDNnDjRNg6ZpcDqd2LVrl759eHgYLpcLhYWFyM3NRW1tLbxeb9Tv6O7uRk1NDbKzs1FUVIR169ZhhF8WTjThxFU+U6dOxYYNG9De3o5Dhw7h+9//Ph544AEcPXoUALB27Vq888472LZtG9ra2tDb24sHH3xQf3w4HEZNTQ2CwSD27t2L1157DVu3bkVjY+P4vioiSn1yg2655RZ55ZVXpL+/XzIyMmTbtm36tuPHjwsAcbvdIiKyc+dOMRqN4vF49JnNmzeLpmkSCASu+hzDw8Pi8/n0paenRwCIz+e70fhENI58Pl/M783rPuYTDofR3NyMwcFBOJ1OtLe3IxQKoaqqSp+ZOXMmSktL4Xa7AQButxuzZ8+G3W7XZ6qrq+H3+/W9p7E0NTXBZrPpS0lJyfXGJqIUEXf5dHZ2Ijc3F1arFU899RS2b9+O8vJyeDweWCwW5OfnR83b7XZ4PB4AgMfjiSqe0e2j266moaEBPp9PX3p6euKNTUQpJu7vcL799tvR0dEBn8+HN998E3V1dWhra0tENp3VaoXVak3ocxBRcsVdPhaLBdOnTwcAVFRU4ODBg3jxxRexdOlSBINB9Pf3R+39eL1eOBwOAIDD4cCBAweift/o2bDRGSKaGG74Op9IJIJAIICKigpkZGSgtbVV39bV1YXu7m44nU4AgNPpRGdnJ/r6+vSZlpYWaJqG8vLyG41CRGkkrj2fhoYG3HfffSgtLcXAwABef/11fPDBB3j//fdhs9mwfPly1NfXo6CgAJqmYdWqVXA6nVi4cCEAYPHixSgvL8cjjzyCjRs3wuPx4JlnnoHL5eLHKqIJJq7y6evr
w89+9jOcOnUKNpsNc+bMwfvvv48f/OAHAIAXXngBRqMRtbW1CAQCqK6uxssvv6w/3mQyYceOHVixYgWcTidycnJQV1eH5557bnxfFRGlPIOIiOoQ8fL7/bDZbPD5fNA0TXUcIroonvcm7+0iIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUuKHy2bBhAwwGA9asWaOvGx4ehsvlQmFhIXJzc1FbWwuv1xv1uO7ubtTU1CA7OxtFRUVYt24dRkZGbiQKEaWZ6y6fgwcP4k9/+hPmzJkTtX7t2rV45513sG3bNrS1taG3txcPPvigvj0cDqOmpgbBYBB79+7Fa6+9hq1bt6KxsfH6XwURpR+5DgMDAzJjxgxpaWmRe+65R1avXi0iIv39/ZKRkSHbtm3TZ48fPy4AxO12i4jIzp07xWg0isfj0Wc2b94smqZJIBAY8/mGh4fF5/PpS09PjwAQn893PfGJKEF8Pl/M783r2vNxuVyoqalBVVVV1Pr29naEQqGo9TNnzkRpaSncbjcAwO12Y/bs2bDb7fpMdXU1/H4/jh49OubzNTU1wWaz6UtJScn1xCaiFBJ3+TQ3N+OTTz5BU1PTFds8Hg8sFgvy8/Oj1tvtdng8Hn3m0uIZ3T66bSwNDQ3w+Xz60tPTE29sIkox5niGe3p6sHr1arS0tCAzMzNRma5gtVphtVqT9nxElHhx7fm0t7ejr68Pd911F8xmM8xmM9ra2vDSSy/BbDbDbrcjGAyiv78/6nFerxcOhwMA4HA4rjj7Nfrz6AwR3fziKp9Fixahs7MTHR0d+jJv3jwsW7ZM/+eMjAy0trbqj+nq6kJ3dzecTicAwOl0orOzE319ffpMS0sLNE1DeXn5OL0sIkp1cX3sysvLw6xZs6LW5eTkoLCwUF+/fPly1NfXo6CgAJqmYdWqVXA6nVi4cCEAYPHixSgvL8cjjzyCjRs3wuPx4JlnnoHL5eJHK6IJJK7yicULL7wAo9GI2tpaBAIBVFdX4+WXX9a3m0wm7NixAytWrIDT6UROTg7q6urw3HPPjXcUIkphBhER1SHi5ff7YbPZ4PP5oGma6jhEdFE8703e20VESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKSEWXWA6yEiAAC/3684CRFdavQ9OfoevZa0LJ+zZ88CAEpKShQnIaKxDAwMwGazXXMmLcunoKAAANDd3f2NLzDV+P1+lJSUoKenB5qmqY4TM+ZOrnTNLSIYGBhAcXHxN86mZfkYjV8fqrLZbGn1P8ylNE1Ly+zMnVzpmDvWHQIecCYiJVg+RKREWpaP1WrFs88+C6vVqjpK3NI1O3MnV7rmjodBYjknRkQ0ztJyz4eI0h/Lh4iUYPkQkRIsHyJSguVDREqkZfls2rQJ06ZNQ2ZmJiorK3HgwAGleT788EPcf//9KC4uhsFgwFtvvRW1XUTQ2NiIKVOmICsrC1VVVfj888+jZs6dO4dly5ZB0zTk5+dj+fLlOH/+fEJzNzU1Yf78+cjLy0NRURGWLFmCrq6uqJnh4WG4XC4UFhYiNzcXtbW18Hq9UTPd3d2oqalBdnY2ioqKsG7dOoyMjCQs9+bNmzFnzhz96l+n04ldu3aldOaxbNiwAQaD
AWvWrEm77ONC0kxzc7NYLBb585//LEePHpUnnnhC8vPzxev1Ksu0c+dO+fWvfy1//etfBYBs3749avuGDRvEZrPJW2+9Jf/4xz/kxz/+sZSVlcmFCxf0mXvvvVfmzp0r+/btk48++kimT58uDz/8cEJzV1dXy6uvvipHjhyRjo4O+eEPfyilpaVy/vx5feapp56SkpISaW1tlUOHDsnChQvlO9/5jr59ZGREZs2aJVVVVXL48GHZuXOnTJo0SRoaGhKW+29/+5u8++678s9//lO6urrkV7/6lWRkZMiRI0dSNvPlDhw4INOmTZM5c+bI6tWr9fXpkH28pF35LFiwQFwul/5zOByW4uJiaWpqUpjqf1xePpFIRBwOhzz//PP6uv7+frFarfLGG2+IiMixY8cEgBw8eFCf2bVrlxgMBjl58mTSsvf19QkAaWtr03NmZGTItm3b9Jnjx48LAHG73SLydfEajUbxeDz6zObNm0XTNAkEAknLfsstt8grr7ySFpkHBgZkxowZ0tLSIvfcc49ePumQfTyl1ceuYDCI9vZ2VFVV6euMRiOqqqrgdrsVJru6EydOwOPxRGW22WyorKzUM7vdbuTn52PevHn6TFVVFYxGI/bv35+0rD6fD8D/fGtAe3s7QqFQVPaZM2eitLQ0Kvvs2bNht9v1merqavj9fhw9ejThmcPhMJqbmzE4OAin05kWmV0uF2pqaqIyAunx73s8pdVd7WfOnEE4HI76Fw8Adrsdn332maJU1+bxeABgzMyj2zweD4qKiqK2m81mFBQU6DOJFolEsGbNGtx9992YNWuWnstisSA/P/+a2cd6baPbEqWzsxNOpxPDw8PIzc3F9u3bUV5ejo6OjpTNDADNzc345JNPcPDgwSu2pfK/70RIq/KhxHG5XDhy5Ag+/vhj1VFicvvtt6OjowM+nw9vvvkm6urq0NbWpjrWNfX09GD16tVoaWlBZmam6jjKpdXHrkmTJsFkMl1x9N/r9cLhcChKdW2jua6V2eFwoK+vL2r7yMgIzp07l5TXtXLlSuzYsQN79uzB1KlT9fUOhwPBYBD9/f3XzD7WaxvdligWiwXTp09HRUUFmpqaMHfuXLz44ospnbm9vR19fX246667YDabYTab0dbWhpdeeglmsxl2uz1lsydCWpWPxWJBRUUFWltb9XWRSAStra1wOp0Kk11dWVkZHA5HVGa/34/9+/frmZ1OJ/r7+9He3q7P7N69G5FIBJWVlQnLJiJYuXIltm/fjt27d6OsrCxqe0VFBTIyMqKyd3V1obu7Oyp7Z2dnVHm2tLRA0zSUl5cnLPvlIpEIAoFASmdetGgROjs70dHRoS/z5s3DsmXL9H9O1ewJofqId7yam5vFarXK1q1b5dixY/Lkk09Kfn5+1NH/ZBsYGJDDhw/L4cOHBYD87ne/k8OHD8uXX34pIl+fas/Pz5e3335bPv30U3nggQfGPNX+7W9/W/bv3y8ff/yxzJgxI+Gn2lesWCE2m00++OADOXXqlL4MDQ3pM0899ZSUlpbK7t275dChQ+J0OsXpdOrbR0/9Ll68WDo6OuS9996TyZMnJ/TU79NPPy1tbW1y4sQJ+fTTT+Xpp58Wg8Egf//731M289VcerYr3bLfqLQrHxGRP/zhD1JaWioWi0UWLFgg+/btU5pnz549AuCKpa6uTkS+Pt2+fv16sdvtYrVaZdGiRdLV1RX1O86ePSsPP/yw5ObmiqZp8thjj8nAwEBCc4+VGYC8+uqr+syFCxfk5z//udxyyy2SnZ0tP/nJT+TUqVNRv+df//qX3HfffZKVlSWTJk2SX/ziFxIKhRKW+/HHH5dbb71VLBaLTJ48WRYtWqQXT6pmvprLyyedst8ofp8PESmRVsd8iOjmwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpMT/B92epTKeTbOUAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "159.71679134439796"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Play one episode with rendering enabled; cell output is its total reward\n",
    "play(True)[-1]"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
